Explain Blackbox Regressors

In this notebook, we use the interpret package to explain blackbox regressors with SHAP, LIME, MorrisSensitivity, and PartialDependence.

This notebook can be found in our examples folder on GitHub.

# install interpret if not already installed
try:
    import interpret
except ModuleNotFoundError:
    !pip install --quiet interpret pandas scikit-learn lime
import numpy as np
import pandas as pd
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from interpret import show

from interpret import set_visualize_provider
from interpret.provider import InlineProvider
set_visualize_provider(InlineProvider())

seed = 42

# Diabetes regression data as pandas objects (features DataFrame, target Series).
X, y = load_diabetes(return_X_y=True, as_frame=True)

# Seed NumPy's global RNG too, so anything later in the notebook that draws
# from it is reproducible; the split itself is seeded explicitly.
np.random.seed(seed)
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.20,
    random_state=seed,
)

Train a blackbox regression system

from sklearn.ensemble import RandomForestRegressor
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline

# The "blackbox" does not have to be a bare regressor: here PCA preprocessing
# feeds a random forest, and the whole pipeline is what gets explained.
pca = PCA()
rf = RandomForestRegressor(random_state=seed)
blackbox_model = Pipeline(steps=[("pca", pca), ("rf", rf)])

# Left as the cell's last expression so the fitted pipeline is displayed.
blackbox_model.fit(X_train, y_train)
Pipeline(steps=[('pca', PCA()), ('rf', RandomForestRegressor(random_state=42))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

Show blackbox model performance

from interpret.perf import RegressionPerf

# Score the fitted pipeline on the held-out test split and render the result.
perf_explainer = RegressionPerf(blackbox_model)
blackbox_perf = perf_explainer.explain_perf(X_test, y_test, name='Blackbox')
show(blackbox_perf)

Local Explanations: How an individual prediction was made

from interpret.blackbox import LimeTabular

# Blackbox explainers only need a model they can query for predictions, plus
# (optionally) a dataset; the training split serves as that dataset here.
lime = LimeTabular(blackbox_model, X_train, random_state=1)

# Pick the instances to explain (the first five test rows) and, optionally,
# their true labels. `.iloc` makes the positional slice explicit: y_test keeps
# the original row labels from y after the shuffled split, so positional and
# label-based indexing are not the same thing here.
lime_local = lime.explain_local(X_test.iloc[:5], y_test.iloc[:5], name='LIME')

show(lime_local, 0)
from interpret.blackbox import ShapKernel

# ShapKernel takes a background dataset. A single row of per-feature training
# medians (built with pandas so it keeps X_train's column names) keeps it small.
background_val = X_train.median().to_frame().T

# NOTE: shap builds compiled against NumPy 1.x fail to import under NumPy 2.0
# (e.g. "AttributeError: `np.obj2sctype` was removed in the NumPy 2.0 release").
# If constructing ShapKernel fails that way, upgrade shap or install 'numpy<2'.
shap = ShapKernel(blackbox_model, background_val)

# Explain the same five test rows as the LIME cell, selected positionally.
shap_local = shap.explain_local(X_test.iloc[:5], y_test.iloc[:5], name='SHAP')
show(shap_local, 0)
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 725, in start
    self.io_loop.start()
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
    self._run_once()
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
    handle._run()
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 513, in dispatch_queue
    await self.process_one()
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 502, in process_one
    await dispatch(*args)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 409, in dispatch_shell
    await result
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
    reply_content = await reply_content
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
    res = shell.run_cell(
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2961, in run_cell
    result = self._run_cell(
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3016, in _run_cell
    result = runner(coro)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3221, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3400, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_3606/2654330808.py", line 4, in <module>
    shap = ShapKernel(blackbox_model, background_val)
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/interpret/blackbox/_shap.py", line 32, in __init__
    from shap import KernelExplainer
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/__init__.py", line 4, in <module>
    from .explainers import other
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/explainers/__init__.py", line 4, in <module>
    from ._gpu_tree import GPUTreeExplainer
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/explainers/_gpu_tree.py", line 5, in <module>
    from ._tree import (
  File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/explainers/_tree.py", line 29, in <module>
    from .. import _cext
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
AttributeError: _ARRAY_API not found
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[6], line 4
      1 from interpret.blackbox import ShapKernel
      3 background_val = pd.DataFrame(np.median(X_train, axis=0).reshape(1, -1), columns=X.columns)
----> 4 shap = ShapKernel(blackbox_model, background_val)
      5 shap_local = shap.explain_local(X_test[:5], y_test[:5], name='SHAP')
      6 show(shap_local, 0)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/interpret/blackbox/_shap.py:32, in ShapKernel.__init__(self, model, data, feature_names, feature_types, **kwargs)
     21 def __init__(self, model, data, feature_names=None, feature_types=None, **kwargs):
     22     """Initializes class.
     23 
     24     Args:
   (...)
     29         **kwargs: Kwargs that will be sent to shap.KernelExplainer
     30     """
---> 32     from shap import KernelExplainer
     34     self.model = model
     35     self.feature_names = feature_names

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/__init__.py:45
     43     have_matplotlib = False
     44 if have_matplotlib:
---> 45     from . import plots
     46     from .plots._bar import bar_legacy as bar_plot
     47     from .plots._beeswarm import summary_legacy as summary_plot

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/__init__.py:6
      3 except ImportError:
      4     raise ImportError("matplotlib is not installed so plotting is not available! Run `pip install matplotlib` to fix this.")
----> 6 from ._bar import bar
      7 from ._beeswarm import beeswarm
      8 from ._benchmark import benchmark

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/_bar.py:9
      7 from ..utils import format_value, ordinal_str
      8 from ..utils._exceptions import DimensionError
----> 9 from . import colors
     10 from ._labels import labels
     11 from ._utils import (
     12     convert_ordering,
     13     dendrogram_coords,
   (...)
     16     sort_inds,
     17 )

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/__init__.py:1
----> 1 from ._colors import (
      2     blue_rgb,
      3     gray_rgb,
      4     light_blue_rgb,
      5     light_red_rgb,
      6     red_blue,
      7     red_blue_circle,
      8     red_blue_no_bounds,
      9     red_blue_transparent,
     10     red_rgb,
     11     red_transparent_blue,
     12     red_white_blue,
     13     transparent_blue,
     14     transparent_red,
     15 )
     17 __all__ = [
     18     "blue_rgb",
     19     "gray_rgb",
   (...)
     30     "transparent_red",
     31 ]

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colors.py:23
     21 red_lch = [54., 90., 0.35470565 + 2* np.pi]
     22 gray_lch = [55., 0., 0.]
---> 23 blue_rgb = lch2rgb(blue_lch)
     24 red_rgb = lch2rgb(red_lch)
     25 gray_rgb = lch2rgb(gray_lch)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colors.py:13, in lch2rgb(x)
     12 def lch2rgb(x):
---> 13     return lab2rgb(lch2lab([[x]]))[0][0]

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colorconv.py:372, in lch2lab(lch)
    346 def lch2lab(lch):
    347     """CIE-LCH to CIE-LAB color space conversion.
    348     LCH is the cylindrical representation of the LAB (Cartesian) colorspace
    349     Parameters
   (...)
    370     >>> img_lab2 = lch2lab(img_lch)
    371     """
--> 372     lch = _prepare_lab_array(lch)
    374     c, h = lch[..., 1], lch[..., 2]
    375     lch[..., 1], lch[..., 2] = c * np.cos(h), c * np.sin(h)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colorconv.py:387, in _prepare_lab_array(arr)
    385 if shape[-1] < 3:
    386     raise ValueError('Input array has less than 3 color channels')
--> 387 return img_as_float(arr, force_copy=True)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colorconv.py:993, in img_as_float(image, force_copy)
    972 def img_as_float(image, force_copy=False):
    973     """Convert an image to floating point format.
    974     This function is similar to `img_as_float64`, but will not convert
    975     lower-precision floating point arrays to `float64`.
   (...)
    991     and can be outside the ranges [0.0, 1.0] or [-1.0, 1.0].
    992     """
--> 993     return convert(image, np.floating, force_copy)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colorconv.py:819, in convert(image, dtype, force_copy, uniform)
    808 itemsize_out = dtypeobj_out.itemsize
    810 # Below, we do an `issubdtype` check.  Its purpose is to find out
    811 # whether we can get away without doing any image conversion.  This happens
    812 # when:
   (...)
    816 #   is a subclass of that type (e.g. `np.floating` will allow
    817 #   `float32` and `float64` arrays through)
--> 819 if np.issubdtype(dtype_in, np.obj2sctype(dtype)):
    820     if force_copy:
    821         image = image.copy()

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/numpy/__init__.py:397, in __getattr__(attr)
    394     raise AttributeError(__former_attrs__[attr])
    396 if attr in __expired_attributes__:
--> 397     raise AttributeError(
    398         f"`np.{attr}` was removed in the NumPy 2.0 release. "
    399         f"{__expired_attributes__[attr]}"
    400     )
    402 if attr == "chararray":
    403     warnings.warn(
    404         "`np.chararray` is deprecated and will be removed from "
    405         "the main namespace in the future. Use an array with a string "
    406         "or bytes dtype instead.", DeprecationWarning, stacklevel=2)

AttributeError: `np.obj2sctype` was removed in the NumPy 2.0 release. Use `np.dtype(obj).type` instead.

Global Explanations: How the model behaves overall

from interpret.blackbox import MorrisSensitivity

# Global view: how sensitive the pipeline's predictions are to each feature,
# estimated with the Morris method (X_train supplies the explainer's data).
sensitivity = MorrisSensitivity(blackbox_model, X_train)
sensitivity_global = sensitivity.explain_global(name="Global Sensitivity")
show(sensitivity_global)
from interpret.blackbox import PartialDependence

# Partial dependence of the pipeline's predictions, computed over X_train.
pdp = PartialDependence(blackbox_model, X_train)
pdp_global = pdp.explain_global(name='Partial Dependence')

# Render one component (index 0) of the global explanation.
show(pdp_global, 0)